from tqdm import tqdm


def train_thumos(net, config, loader_iter, optimizer, criterion, logger, step):
    """Run one optimisation step by accumulating ``config.batch_size`` samples.

    Each of the ``config.batch_size`` items pulled from ``loader_iter`` is
    moved to the GPU, forwarded through ``net`` and scored by ``criterion``.
    The per-item costs are averaged and back-propagated once, followed by a
    single ``optimizer.step()``. The individual loss terms returned by
    ``criterion`` are averaged over the same count and written to ``logger``
    as ``"loss/<key>"`` at ``step``.

    Args:
        net: model; called as ``net(data, label)`` and unpacked into
            ``(vid_score, cas_sigmoid_fuse, features, output)``.
        config: object providing ``batch_size`` (number of items accumulated
            into this one step).
        loader_iter: iterator yielding 7-tuples; elements 1-4 are used as
            data, label, point annotation and stored info.
        optimizer: optimizer; gradients are cleared before accumulation and
            one step is taken after the backward pass.
        criterion: called as ``criterion(vid_score, cas_sigmoid_fuse, output,
            features, stored_info, label, point_anno, step)`` and returning
            ``(cost, loss_dict)``.
        logger: object exposing ``log_value(name, value, step)``.
        step: current training step, forwarded to ``criterion`` and ``logger``.
    """
    net.train()
    optimizer.zero_grad()

    costs = []
    loss_history = {}

    for _ in range(config.batch_size):
        _, data, label, point_anno, stored_info, _, _ = next(loader_iter)

        data = data.cuda()
        label = label.cuda()
        point_anno = point_anno.cuda()

        vid_score, cas_sigmoid_fuse, features, output = net(data, label)

        cost, loss = criterion(vid_score, cas_sigmoid_fuse, output, features,
                               stored_info, label, point_anno, step)
        costs.append(cost)

        for key, value in loss.items():
            # Keep only plain floats for logging. Tensors of any sign are
            # detached so the log buffer never holds an autograd graph, and
            # plain Python numbers (which have no .detach()) are accepted.
            if hasattr(value, "detach"):
                value = value.detach().cpu().item()
            loss_history.setdefault(key, []).append(float(value))

    # One backward pass over the mean cost gives the averaged gradient of
    # all accumulated items.
    mean_cost = sum(costs) / config.batch_size
    mean_cost.backward()
    optimizer.step()

    for key, values in loss_history.items():
        logger.log_value("loss/" + key, sum(values) / config.batch_size, step)


def train_activity(net, config, train_loader, optimizer, criterion, logger, step):
    """Train ``net`` for one full pass over ``train_loader``.

    Every batch is moved to the GPU, forwarded through ``net``, scored by
    ``criterion`` and followed by its own ``zero_grad`` / ``backward`` /
    ``step``. The loss terms returned by ``criterion`` are averaged over the
    batches of the pass and written to ``logger`` as ``"loss/<key>"`` at
    ``step``. Nothing is logged if the loader is empty.

    Args:
        net: model; called as ``net(data, label)`` and unpacked into
            ``(vid_score, cas_sigmoid_fuse, features, output)``.
        config: object providing ``num_iters`` (shown in the progress bar).
        train_loader: sized iterable yielding 7-tuples; elements 1-5 are used
            as data, label, point annotation, stored info and video name.
        optimizer: optimizer stepped once per batch.
        criterion: called with the batch outputs and returning
            ``(cost, loss_dict)``.
        logger: object exposing ``log_value(name, value, step)``.
        step: current training step, shown in the progress bar and forwarded
            to ``criterion`` and ``logger``.
    """
    net.train()

    loss_history = {}
    num_batches = 0

    progress = tqdm(
        train_loader,
        total=len(train_loader),
        dynamic_ncols=True,
        desc='{}/{},training----'.format(step, config.num_iters),
        position=0,
    )
    for _, data, label, point_anno, stored_info, vid_name, _ in progress:
        data = data.cuda()
        label = label.cuda()
        point_anno = point_anno.cuda()

        vid_score, cas_sigmoid_fuse, features, output = net(data, label)

        # NOTE(review): unlike train_thumos, the first criterion argument here
        # is vid_name rather than vid_score (which is computed but unused).
        # Kept as-is, presumably because the activity criterion expects the
        # video names -- confirm against the criterion definition.
        cost, loss = criterion(vid_name, cas_sigmoid_fuse, output, features,
                               stored_info, label, point_anno, step)

        for key, value in loss.items():
            # Keep only detached plain floats so no autograd graph outlives
            # its batch, and plain Python numbers are accepted too.
            if hasattr(value, "detach"):
                value = value.detach().cpu().item()
            loss_history.setdefault(key, []).append(float(value))

        # Gradients are cleared per batch: each batch is its own update step.
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        num_batches += 1

    if num_batches:
        for key, values in loss_history.items():
            logger.log_value("loss/" + key, sum(values) / num_batches, step)

